{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# OpenAI Tool calling capabilities with Databricks Unity Catalog\n",
    "\n",
    "Import this notebook into a Databricks workspace and run it there."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install unitycatalog-openai[databricks] mlflow langchain_openai\n",
    "%restart_python"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "18072ae8-83e5-4880-9b89-a92ee9f67641",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Create a DatabricksFunctionClient\n",
    "\n",
    "The client uses a serverless connection endpoint and supports both CRUD operations on Python functions in Unity Catalog (UC) and execution of those functions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "fae8cd2a-6f52-4da2-8dbb-4d95d8284e28",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "from unitycatalog.ai.core.base import set_uc_function_client\n",
    "from unitycatalog.ai.core.databricks import DatabricksFunctionClient\n",
    "\n",
    "# Instantiate the Databricks function client with its default arguments.\n",
    "client = DatabricksFunctionClient()\n",
    "\n",
    "# Register this client as the default UC function client so that toolkits\n",
    "# created later in the notebook do not need it passed explicitly.\n",
    "set_uc_function_client(client)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "d5a6f448-21a5-4fe4-82e3-f952df401834",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Create a UC function for executing Python code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "9509aeb8-d4c7-41b9-af8f-473ffdd5379c",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "# Replace with your own catalog and schema; the functions below are created\n",
    "# inside this schema.\n",
    "CATALOG = \"ml\"\n",
    "SCHEMA = \"serena_test\"\n",
    "\n",
    "\n",
    "def execute_python_code(code: str) -> str:\n",
    "    \"\"\"\n",
    "    Executes the given python code and returns its stdout.\n",
    "\n",
    "    Args:\n",
    "      code: Python code to execute. Remember to print the final result to stdout.\n",
    "    \"\"\"\n",
    "    # Imports are kept inside the function body so the function carries\n",
    "    # everything it needs with it.\n",
    "    import sys\n",
    "    from io import StringIO\n",
    "\n",
    "    stdout = StringIO()\n",
    "    original_stdout = sys.stdout\n",
    "    sys.stdout = stdout\n",
    "    try:\n",
    "        # NOTE(review): exec runs arbitrary caller-supplied code; only expose\n",
    "        # this function where that is acceptable.\n",
    "        exec(code)\n",
    "    finally:\n",
    "        # Always put the previous stdout back, even when the executed code\n",
    "        # raises, so later output is not swallowed by the discarded buffer.\n",
    "        sys.stdout = original_stdout\n",
    "    return stdout.getvalue()\n",
    "\n",
    "\n",
    "# Register execute_python_code as a UC function under CATALOG.SCHEMA.\n",
    "# replace=True presumably overwrites an existing function of the same name so\n",
    "# the cell can be re-run - confirm against the SDK documentation.\n",
    "function_info = client.create_python_function(\n",
    "    func=execute_python_code, catalog=CATALOG, schema=SCHEMA, replace=True\n",
    ")\n",
    "# Keep the function's full name; later cells use it to execute the function\n",
    "# and to build the tool list.\n",
    "python_execution_function_name = function_info.full_name"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "47c1b20e-96d2-4962-afd3-afbc62b705f5",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionExecutionResult(error=None, format='SCALAR', value='2\\n', truncated=None)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: run the registered function directly through the client.\n",
    "client.execute_function(python_execution_function_name, {\"code\": \"print(1+1)\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "944e5286-3e4c-42d4-8ac3-c061209d269e",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Create a UC function for translation\n",
    "\n",
    "Use the Databricks built-in [ai_translate](https://docs.databricks.com/en/sql/language-manual/functions/ai_translate.html) function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "da1b8345-e924-4eb1-8820-a9920514fc7f",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionInfo(browse_only=None, catalog_name='ml', comment=None, created_at=1729136621857, created_by='serena.ruan@databricks.com', data_type=<ColumnTypeName.STRING: 'STRING'>, external_language=None, external_name=None, full_data_type='STRING', full_name='ml.serena_test.translate', function_id='423d49b1-8c8b-4a14-a6c7-eed1e3a23afc', input_params=FunctionParameterInfos(parameters=[FunctionParameterInfo(name='content', type_text='string', type_name=<ColumnTypeName.STRING: 'STRING'>, position=0, comment='the text to be translated', parameter_default=None, parameter_mode=None, parameter_type=<FunctionParameterType.PARAM: 'PARAM'>, type_interval_type=None, type_json='{\"name\":\"content\",\"type\":\"string\",\"nullable\":true,\"metadata\":{\"comment\":\"the text to be translated\"}}', type_precision=0, type_scale=0), FunctionParameterInfo(name='language', type_text='string', type_name=<ColumnTypeName.STRING: 'STRING'>, position=1, comment='the target language code to translate the content to', parameter_default=None, parameter_mode=None, parameter_type=<FunctionParameterType.PARAM: 'PARAM'>, type_interval_type=None, type_json='{\"name\":\"language\",\"type\":\"string\",\"nullable\":true,\"metadata\":{\"comment\":\"the target language code to translate the content to\"}}', type_precision=0, type_scale=0)]), is_deterministic=True, is_null_call=None, metastore_id='19a85dee-54bc-43a2-87ab-023d0ec16013', name='translate', owner='serena.ruan@databricks.com', parameter_style=<FunctionInfoParameterStyle.S: 'S'>, 
properties='{\"sqlConfig.spark.sql.ansi.enabled\":\"true\",\"referredTempFunctionsNames\":\"[]\",\"sqlConfig.spark.sql.streaming.statefulOperator.stateRebalancing.enabled\":\"false\",\"catalogAndNamespace.part.0\":\"main\",\"sqlConfig.spark.sql.legacy.createHiveTableByDefault\":\"false\",\"sqlConfig.spark.sql.shuffleDependency.skipMigration.enabled\":\"true\",\"sqlConfig.spark.sql.streaming.stopTimeout\":\"15s\",\"referredTempViewNames\":\"[]\",\"catalogAndNamespace.part.1\":\"default\",\"sqlConfig.spark.sql.readSideCharPadding\":\"true\",\"sqlConfig.spark.sql.variable.substitute\":\"false\",\"sqlConfig.spark.databricks.sql.functions.aiForecast.enabled\":\"true\",\"sqlConfig.spark.sql.sources.default\":\"delta\",\"sqlConfig.spark.sql.hive.convertCTAS\":\"true\",\"referredTempVariableNames\":\"[]\",\"sqlConfig.spark.sql.functions.remoteHttpClient.retryOnSocketTimeoutException\":\"true\",\"sqlConfig.spark.sql.sources.commitProtocolClass\":\"com.databricks.sql.transaction.directory.DirectoryAtomicCommitProtocol\",\"sqlConfig.spark.sql.functions.remoteHttpClient.retryOn400TimeoutError\":\"true\",\"catalogAndNamespace.numParts\":\"2\",\"sqlConfig.spark.sql.stableDerivedColumnAlias.enabled\":\"true\",\"sqlConfig.spark.sql.parquet.compression.codec\":\"snappy\",\"sqlConfig.spark.sql.streaming.stateStore.providerClass\":\"com.databricks.sql.streaming.state.RocksDBStateStoreProvider\"}', return_params=None, routine_body=<FunctionInfoRoutineBody.SQL: 'SQL'>, routine_definition='(SELECT ai_translate(content, language))', routine_dependencies=None, schema_name='serena_test', security_type=<FunctionInfoSecurityType.DEFINER: 'DEFINER'>, specific_name='translate', sql_data_access=<FunctionInfoSqlDataAccess.CONTAINS_SQL: 'CONTAINS_SQL'>, sql_path=None, updated_at=1729136621857, updated_by='serena.ruan@databricks.com')"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "translate_function_name = f\"{CATALOG}.{SCHEMA}.translate\"\n",
    "# SQL function that wraps ai_translate. The parameter COMMENTs show up as the\n",
    "# parameter descriptions in the generated tool schema.\n",
    "# NOTE(review): no function-level COMMENT is declared, so the tool's own\n",
    "# description ends up empty.\n",
    "sql_body = f\"\"\"CREATE OR REPLACE FUNCTION {translate_function_name}(content STRING COMMENT 'the text to be translated', language STRING COMMENT 'the target language code to translate the content to')\n",
    "RETURNS STRING\n",
    "RETURN SELECT ai_translate(content, language)\n",
    "\"\"\"\n",
    "client.create_function(sql_function_body=sql_body)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "045e5ae0-3b27-449b-bc84-bf7335b72b9b",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FunctionExecutionResult(error=None, format='SCALAR', value='¿Qué es Databricks?', truncated=None)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check: translate an English sentence using the target language code 'es'.\n",
    "client.execute_function(\n",
    "    translate_function_name, {\"content\": \"What is Databricks?\", \"language\": \"es\"}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "5e61bb13-e40d-4d22-a0aa-f58367adba24",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Create tools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "952ed06f-89c6-4495-81d6-ca1fe2a24c9b",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'type': 'function',\n",
       "  'function': {'name': 'ml__serena_test__execute_python_code',\n",
       "   'strict': True,\n",
       "   'parameters': {'properties': {'code': {'anyOf': [{'type': 'string'},\n",
       "       {'type': 'null'}],\n",
       "      'description': 'Python code to execute. Remember to print the final result to stdout.',\n",
       "      'title': 'Code'}},\n",
       "    'title': 'ml__serena_test__execute_python_code__params',\n",
       "    'type': 'object',\n",
       "    'additionalProperties': False,\n",
       "    'required': ['code']},\n",
       "   'description': 'Executes the given python code and returns its stdout.'}},\n",
       " {'type': 'function',\n",
       "  'function': {'name': 'ml__serena_test__translate',\n",
       "   'strict': True,\n",
       "   'parameters': {'properties': {'content': {'anyOf': [{'type': 'string'},\n",
       "       {'type': 'null'}],\n",
       "      'description': 'the text to be translated',\n",
       "      'title': 'Content'},\n",
       "     'language': {'anyOf': [{'type': 'string'}, {'type': 'null'}],\n",
       "      'description': 'the target language code to translate the content to',\n",
       "      'title': 'Language'}},\n",
       "    'title': 'ml__serena_test__translate__params',\n",
       "    'type': 'object',\n",
       "    'additionalProperties': False,\n",
       "    'required': ['content', 'language']},\n",
       "   'description': ''}}]"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from unitycatalog.ai.openai.toolkit import UCFunctionToolkit\n",
    "\n",
    "# Build tool definitions for both UC functions registered above.\n",
    "uc_function_names = [python_execution_function_name, translate_function_name]\n",
    "toolkit = UCFunctionToolkit(function_names=uc_function_names)\n",
    "tools = toolkit.tools\n",
    "tools"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {},
     "inputWidgets": {},
     "nuid": "aafb2a37-27d6-42b9-8220-b71b5ab062ea",
     "showTitle": false,
     "title": ""
    }
   },
   "source": [
    "#### Use the tools in OpenAI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "d3d6a36b-42c4-480f-a446-752a11ca2bd0",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Replace with your own API key. Do not commit a real key with the notebook;\n",
    "# prefer loading it from a secret store such as a Databricks secret scope.\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"<OPENAI_API_KEY>\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "633f6958-2010-4b54-96df-c637036fc1a6",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "ChatCompletion(id='chatcmpl-AJBnsZzsh5HVxhGPExplOSwJTwvGs', choices=[Choice(finish_reason='tool_calls', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, role='assistant', function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='call_WBDFysLam5QIPaNdD7cpLzHP', function=Function(arguments='{\"code\":\"print(2**10)\"}', name='ml__serena_test__execute_python_code'), type='function')]))], created=1729136968, model='gpt-4o-mini-2024-07-18', object='chat.completion', service_tier=None, system_fingerprint='fp_e2bde53e6e', usage=CompletionUsage(completion_tokens=26, prompt_tokens=175, total_tokens=201, completion_tokens_details=CompletionTokensDetails(audio_tokens=None, reasoning_tokens=0), prompt_tokens_details=PromptTokensDetails(audio_tokens=None, cached_tokens=0)))\n"
     ]
    }
   ],
   "source": [
    "import openai\n",
    "\n",
    "# Conversation: a system prompt that encourages tool use, then the question.\n",
    "system_message = {\n",
    "    \"role\": \"system\",\n",
    "    \"content\": \"You are a helpful customer support assistant. Use the supplied tools to assist the user.\",\n",
    "}\n",
    "user_message = {\"role\": \"user\", \"content\": \"What is the result of 2**10?\"}\n",
    "messages = [system_message, user_message]\n",
    "\n",
    "response = openai.chat.completions.create(\n",
    "    messages=messages,\n",
    "    model=\"gpt-4o-mini\",\n",
    "    tools=tools,\n",
    ")\n",
    "# Inspect the raw completion returned by the model.\n",
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "e9c3f7f6-0bdd-4861-ab98-a88a5e47e231",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'content': None, 'refusal': None, 'role': 'assistant', 'tool_calls': [{'id': 'call_WBDFysLam5QIPaNdD7cpLzHP', 'function': {'arguments': '{\"code\":\"print(2**10)\"}', 'name': 'ml__serena_test__execute_python_code'}, 'type': 'function'}]}, {'role': 'tool', 'content': '{\"content\": \"1024\\\\n\"}', 'tool_call_id': 'call_WBDFysLam5QIPaNdD7cpLzHP'}]\n"
     ]
    }
   ],
   "source": [
    "# Execute the tool call(s) requested in the response and build the follow-up\n",
    "# messages (the assistant tool-call message plus the tool results) using the\n",
    "# helper provided in the SDK.\n",
    "from unitycatalog.ai.openai.utils import generate_tool_call_messages\n",
    "\n",
    "# NOTE(review): this rebinds `messages`, so the original system/user turns are\n",
    "# no longer in it, and no final model answer is requested in this cell.\n",
    "messages = generate_tool_call_messages(response=response)\n",
    "messages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "application/vnd.databricks.v1+cell": {
     "cellMetadata": {
      "byteLimit": 2048000,
      "rowLimit": 10000
     },
     "inputWidgets": {},
     "nuid": "2971a920-d117-43ac-b086-54c0e5dd78a4",
     "showTitle": false,
     "title": ""
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[{'content': None, 'refusal': None, 'role': 'assistant', 'tool_calls': [{'id': 'call_GXwEMdBD2gMvvqjvG1mwgB1f', 'function': {'arguments': '{\"content\":\"我是一个机器人\",\"language\":\"en\"}', 'name': 'ml__serena_test__translate'}, 'type': 'function'}]}, {'role': 'tool', 'content': '{\"content\": \"I am a robot\"}', 'tool_call_id': 'call_GXwEMdBD2gMvvqjvG1mwgB1f'}]\n"
     ]
    }
   ],
   "source": [
    "# Check that the model picks the translation tool when asked a translation\n",
    "# question instead of a calculation.\n",
    "messages = [\n",
    "    {\n",
    "        \"role\": \"system\",\n",
    "        \"content\": \"You are a helpful customer support assistant. Use the supplied tools to assist the user.\",\n",
    "    },\n",
    "    {\"role\": \"user\", \"content\": \"What is English for '我是一个机器人'?\"},\n",
    "]\n",
    "response = openai.chat.completions.create(\n",
    "    model=\"gpt-4o-mini\",\n",
    "    messages=messages,\n",
    "    tools=tools,\n",
    ")\n",
    "# Run the requested tool call and collect the resulting messages; this\n",
    "# rebinds `messages` to the tool-call messages only.\n",
    "messages = generate_tool_call_messages(response=response)\n",
    "messages"
   ]
  }
 ],
 "metadata": {
  "application/vnd.databricks.v1+notebook": {
   "dashboards": [],
   "environmentMetadata": {
    "base_environment": "",
    "client": "1"
   },
   "language": "python",
   "notebookMetadata": {
    "mostRecentlyExecutedCommandWithImplicitDF": {
     "commandId": -1,
     "dataframes": [
      "_sqldf"
     ]
    },
    "pythonIndentUnit": 2
   },
   "notebookName": "ucai-openai sdk",
   "widgets": {}
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
